{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "This is a companion notebook for the book [Deep Learning with Python, Second Edition](https://www.manning.com/books/deep-learning-with-python-second-edition?a_aid=keras&a_bid=76564dff). For readability, it only contains runnable code blocks and section titles, and omits everything else in the book: text paragraphs, figures, and pseudocode.\n\n**If you want to be able to follow what's going on, I recommend reading the notebook side by side with your copy of the book.**\n\nThis notebook was generated for TensorFlow 2.6."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "# Generative deep learning"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "## Text generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### A brief history of generative deep learning for sequence generation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### How do you generate sequence data?"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### The importance of the sampling strategy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Reweighting a probability distribution to a different temperature**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "def reweight_distribution(original_distribution, temperature=0.5):\n",
    "    distribution = np.log(original_distribution) / temperature\n",
    "    distribution = np.exp(distribution)\n",
    "    return distribution / np.sum(distribution)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Implementing text generation with Keras"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### Preparing the data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Downloading and uncompressing the IMDB movie reviews dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "!wget https://ai.stanford.edu/~amaas/data/sentiment/aclImdb_v1.tar.gz\n",
    "!tar -xf aclImdb_v1.tar.gz"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Creating a dataset from text files (one file = one sample)**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow import keras\n",
    "dataset = keras.utils.text_dataset_from_directory(\n",
    "    directory=\"aclImdb\", label_mode=None, batch_size=256)\n",
    "dataset = dataset.map(lambda x: tf.strings.regex_replace(x, \"<br />\", \" \"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Preparing a `TextVectorization` layer**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras.layers import TextVectorization\n",
    "\n",
    "sequence_length = 100\n",
    "vocab_size = 15000\n",
    "text_vectorization = TextVectorization(\n",
    "    max_tokens=vocab_size,\n",
    "    output_mode=\"int\",\n",
    "    output_sequence_length=sequence_length,\n",
    ")\n",
    "text_vectorization.adapt(dataset)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Setting up a language modeling dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "def prepare_lm_dataset(text_batch):\n",
    "    vectorized_sequences = text_vectorization(text_batch)\n",
    "    x = vectorized_sequences[:, :-1]\n",
    "    y = vectorized_sequences[:, 1:]\n",
    "    return x, y\n",
    "\n",
    "lm_dataset = dataset.map(prepare_lm_dataset, num_parallel_calls=4)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "#### A Transformer-based sequence-to-sequence model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "from tensorflow.keras import layers\n",
    "\n",
    "class PositionalEmbedding(layers.Layer):\n",
    "    def __init__(self, sequence_length, input_dim, output_dim, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.token_embeddings = layers.Embedding(\n",
    "            input_dim=input_dim, output_dim=output_dim)\n",
    "        self.position_embeddings = layers.Embedding(\n",
    "            input_dim=sequence_length, output_dim=output_dim)\n",
    "        self.sequence_length = sequence_length\n",
    "        self.input_dim = input_dim\n",
    "        self.output_dim = output_dim\n",
    "\n",
    "    def call(self, inputs):\n",
    "        length = tf.shape(inputs)[-1]\n",
    "        positions = tf.range(start=0, limit=length, delta=1)\n",
    "        embedded_tokens = self.token_embeddings(inputs)\n",
    "        embedded_positions = self.position_embeddings(positions)\n",
    "        return embedded_tokens + embedded_positions\n",
    "\n",
    "    def compute_mask(self, inputs, mask=None):\n",
    "        return tf.math.not_equal(inputs, 0)\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super(PositionalEmbedding, self).get_config()\n",
    "        config.update({\n",
    "            \"output_dim\": self.output_dim,\n",
    "            \"sequence_length\": self.sequence_length,\n",
    "            \"input_dim\": self.input_dim,\n",
    "        })\n",
    "        return config\n",
    "\n",
    "\n",
    "class TransformerDecoder(layers.Layer):\n",
    "    def __init__(self, embed_dim, dense_dim, num_heads, **kwargs):\n",
    "        super().__init__(**kwargs)\n",
    "        self.embed_dim = embed_dim\n",
    "        self.dense_dim = dense_dim\n",
    "        self.num_heads = num_heads\n",
    "        self.attention_1 = layers.MultiHeadAttention(\n",
    "          num_heads=num_heads, key_dim=embed_dim)\n",
    "        self.attention_2 = layers.MultiHeadAttention(\n",
    "          num_heads=num_heads, key_dim=embed_dim)\n",
    "        self.dense_proj = keras.Sequential(\n",
    "            [layers.Dense(dense_dim, activation=\"relu\"),\n",
    "             layers.Dense(embed_dim),]\n",
    "        )\n",
    "        self.layernorm_1 = layers.LayerNormalization()\n",
    "        self.layernorm_2 = layers.LayerNormalization()\n",
    "        self.layernorm_3 = layers.LayerNormalization()\n",
    "        self.supports_masking = True\n",
    "\n",
    "    def get_config(self):\n",
    "        config = super(TransformerDecoder, self).get_config()\n",
    "        config.update({\n",
    "            \"embed_dim\": self.embed_dim,\n",
    "            \"num_heads\": self.num_heads,\n",
    "            \"dense_dim\": self.dense_dim,\n",
    "        })\n",
    "        return config\n",
    "\n",
    "    def get_causal_attention_mask(self, inputs):\n",
    "        input_shape = tf.shape(inputs)\n",
    "        batch_size, sequence_length = input_shape[0], input_shape[1]\n",
    "        i = tf.range(sequence_length)[:, tf.newaxis]\n",
    "        j = tf.range(sequence_length)\n",
    "        mask = tf.cast(i >= j, dtype=\"int32\")\n",
    "        mask = tf.reshape(mask, (1, input_shape[1], input_shape[1]))\n",
    "        mult = tf.concat(\n",
    "            [tf.expand_dims(batch_size, -1),\n",
    "             tf.constant([1, 1], dtype=tf.int32)], axis=0)\n",
    "        return tf.tile(mask, mult)\n",
    "\n",
    "    def call(self, inputs, encoder_outputs, mask=None):\n",
    "        causal_mask = self.get_causal_attention_mask(inputs)\n",
    "        if mask is not None:\n",
    "            padding_mask = tf.cast(\n",
    "                mask[:, tf.newaxis, :], dtype=\"int32\")\n",
    "            padding_mask = tf.minimum(padding_mask, causal_mask)\n",
    "        else:\n",
    "            padding_mask = mask\n",
    "        attention_output_1 = self.attention_1(\n",
    "            query=inputs,\n",
    "            value=inputs,\n",
    "            key=inputs,\n",
    "            attention_mask=causal_mask)\n",
    "        attention_output_1 = self.layernorm_1(inputs + attention_output_1)\n",
    "        attention_output_2 = self.attention_2(\n",
    "            query=attention_output_1,\n",
    "            value=encoder_outputs,\n",
    "            key=encoder_outputs,\n",
    "            attention_mask=padding_mask,\n",
    "        )\n",
    "        attention_output_2 = self.layernorm_2(\n",
    "            attention_output_1 + attention_output_2)\n",
    "        proj_output = self.dense_proj(attention_output_2)\n",
    "        return self.layernorm_3(attention_output_2 + proj_output)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**A simple Transformer-based language model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "from tensorflow.keras import layers\n",
    "embed_dim = 256\n",
    "latent_dim = 2048\n",
    "num_heads = 2\n",
    "\n",
    "inputs = keras.Input(shape=(None,), dtype=\"int64\")\n",
    "x = PositionalEmbedding(sequence_length, vocab_size, embed_dim)(inputs)\n",
    "x = TransformerDecoder(embed_dim, latent_dim, num_heads)(x, x)\n",
    "outputs = layers.Dense(vocab_size, activation=\"softmax\")(x)\n",
    "model = keras.Model(inputs, outputs)\n",
    "model.compile(loss=\"sparse_categorical_crossentropy\", optimizer=\"rmsprop\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### A text-generation callback with variable-temperature sampling"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**The text-generation callback**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "tokens_index = dict(enumerate(text_vectorization.get_vocabulary()))\n",
    "\n",
    "def sample_next(predictions, temperature=1.0):\n",
    "    predictions = np.asarray(predictions).astype(\"float64\")\n",
    "    predictions = np.log(predictions) / temperature\n",
    "    exp_preds = np.exp(predictions)\n",
    "    predictions = exp_preds / np.sum(exp_preds)\n",
    "    probas = np.random.multinomial(1, predictions, 1)\n",
    "    return np.argmax(probas)\n",
    "\n",
    "class TextGenerator(keras.callbacks.Callback):\n",
    "    def __init__(self,\n",
    "                 prompt,\n",
    "                 generate_length,\n",
    "                 model_input_length,\n",
    "                 temperatures=(1.,),\n",
    "                 print_freq=1):\n",
    "        self.prompt = prompt\n",
    "        self.generate_length = generate_length\n",
    "        self.model_input_length = model_input_length\n",
    "        self.temperatures = temperatures\n",
    "        self.print_freq = print_freq\n",
    "        vectorized_prompt = text_vectorization([prompt])[0].numpy()\n",
    "        self.prompt_length = np.nonzero(vectorized_prompt == 0)[0][0]\n",
    "\n",
    "    def on_epoch_end(self, epoch, logs=None):\n",
    "        if (epoch + 1) % self.print_freq != 0:\n",
    "            return\n",
    "        for temperature in self.temperatures:\n",
    "            print(\"== Generating with temperature\", temperature)\n",
    "            sentence = self.prompt\n",
    "            for i in range(self.generate_length):\n",
    "                tokenized_sentence = text_vectorization([sentence])\n",
    "                predictions = self.model(tokenized_sentence)\n",
    "                next_token = sample_next(\n",
    "                    predictions[0, self.prompt_length - 1 + i, :]\n",
    "                )\n",
    "                sampled_token = tokens_index[next_token]\n",
    "                sentence += \" \" + sampled_token\n",
    "            print(sentence)\n",
    "\n",
    "prompt = \"This movie\"\n",
    "text_gen_callback = TextGenerator(\n",
    "    prompt,\n",
    "    generate_length=50,\n",
    "    model_input_length=sequence_length,\n",
    "    temperatures=(0.2, 0.5, 0.7, 1., 1.5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "**Fitting the language model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 0,
   "metadata": {
    "colab_type": "code"
   },
   "outputs": [],
   "source": [
    "model.fit(lm_dataset, epochs=200, callbacks=[text_gen_callback])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text"
   },
   "source": [
    "### Wrapping up"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "collapsed_sections": [],
   "name": "chapter12_part01_text-generation.i",
   "private_outputs": false,
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}